{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第5章 循环神经网络\n",
    "### 5.2 循环神经网络"
   ]
  },
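  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first cell below evaluates, on random data, the two matrix products that appear in the standard hidden-state update\n",
    "\n",
    "$$\\boldsymbol{H}_t = \\phi(\\boldsymbol{X}_t \\boldsymbol{W}_{xh} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hh} + \\boldsymbol{b}_h).$$\n",
    "\n",
    "The second cell verifies the block-matrix identity behind many RNN implementations: concatenating $\\boldsymbol{X}_t$ and $\\boldsymbol{H}_{t-1}$ along the columns, and stacking $\\boldsymbol{W}_{xh}$ on top of $\\boldsymbol{W}_{hh}$ along the rows, gives\n",
    "\n",
    "$$[\\boldsymbol{X}_t, \\boldsymbol{H}_{t-1}] \\, [\\boldsymbol{W}_{xh}; \\boldsymbol{W}_{hh}] = \\boldsymbol{X}_t \\boldsymbol{W}_{xh} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hh},$$\n",
    "\n",
    "so both cells print the same tensor."
   ]
  },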
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0731, -1.4126,  3.0787,  1.3176],\n",
       "        [-1.0251, -0.2845,  1.0140, -0.8578],\n",
       "        [-0.5749, -0.0538, -2.9409,  1.4076]])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "X, W_xh = torch.randn(3, 1), torch.randn(1, 4)\n",
    "H, W_hh = torch.randn(3, 4), torch.randn(4, 4)\n",
    "torch.matmul(X, W_xh)+torch.matmul(H, W_hh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0731, -1.4126,  3.0787,  1.3176],\n",
       "        [-1.0251, -0.2845,  1.0140, -0.8578],\n",
       "        [-0.5749, -0.0538, -2.9409,  1.4076]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.matmul(torch.cat((X, H), dim=1), torch.cat((W_xh, W_hh), dim=0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.3 语言模型数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'想要有直升机\\n想要和你飞到宇宙去\\n想要和你融化在一起\\n融化在宇宙里\\n我每天每天每'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import zipfile\n",
    "with zipfile.ZipFile('data/jaychou_lyrics.txt.zip') as zin:\n",
    "    with zin.open('jaychou_lyrics.txt') as f:\n",
    "        corpus_chars = f.read().decode('utf-8')\n",
    "corpus_chars[:40]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus_chars = corpus_chars.replace('\\n', ' ').replace('\\r', ' ')\n",
    "corpus_chars = corpus_chars[:10000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1027"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx_to_char = list(set(corpus_chars))\n",
    "char_to_idx = dict([(char, i) for i, char in enumerate(idx_to_char)])\n",
    "vocab_size = len(char_to_idx)\n",
    "vocab_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chars: 想要有直升机 想要和你飞到宇宙去 想要和\n",
      "indices: [40, 32, 825, 884, 598, 691, 101, 40, 32, 178, 151, 417, 265, 824, 755, 777, 101, 40, 32, 178]\n"
     ]
    }
   ],
   "source": [
    "corpus_indices = [char_to_idx[char] for char in corpus_chars]\n",
    "sample = corpus_indices[:20]\n",
    "print('chars:', ''.join([idx_to_char[idx] for idx in sample]))\n",
    "print('indices:', sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "def data_iter_random(corpus_indices, batch_size, num_steps, device=None):\n",
    "    # 减1是因为输出的索引x是相应输入的索引y加1\n",
    "    # 这里是计算一共生成多少个样本\n",
    "    num_examples = (len(corpus_indices) - 1) // num_steps\n",
    "    # 循环多少次可以遍历所有样本\n",
    "    epoch_size = num_examples // batch_size\n",
    "    # 每一条样本的索引\n",
    "    example_indices = list(range(num_examples))\n",
    "    # 随机打乱样本索引\n",
    "    random.shuffle(example_indices)\n",
    "    # 返回从pos开始的长为num_steps的序列\n",
    "    def _data(pos):\n",
    "        return corpus_indices[pos: pos+num_steps]\n",
    "    if device is None:\n",
    "        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    # 构建每一个epoch内的样本\n",
    "    for i in range(epoch_size):\n",
    "        # 每次读取batch_size个随机样本\n",
    "        i = i * batch_size\n",
    "        batch_indices = example_indices[i: i+batch_size]\n",
    "        X = [_data(j*num_steps) for j in batch_indices]\n",
    "        Y = [_data(j*num_steps+1) for j in batch_indices]\n",
    "        yield torch.tensor(X, dtype=torch.float32, device=device), torch.tensor(Y, dtype=torch.float32, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X:  tensor([[ 6.,  7.,  8.,  9., 10., 11.],\n",
      "        [18., 19., 20., 21., 22., 23.]], device='cuda:0') \n",
      "Y: tensor([[ 7.,  8.,  9., 10., 11., 12.],\n",
      "        [19., 20., 21., 22., 23., 24.]], device='cuda:0') \n",
      "\n",
      "X:  tensor([[12., 13., 14., 15., 16., 17.],\n",
      "        [ 0.,  1.,  2.,  3.,  4.,  5.]], device='cuda:0') \n",
      "Y: tensor([[13., 14., 15., 16., 17., 18.],\n",
      "        [ 1.,  2.,  3.,  4.,  5.,  6.]], device='cuda:0') \n",
      "\n"
     ]
    }
   ],
   "source": [
    "my_seq = list(range(30))\n",
    "for X, Y in data_iter_random(my_seq, batch_size=2, num_steps=6):\n",
    "    print('X: ', X, '\\nY:', Y, '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_iter_consecutive(corpus_indices, batch_size, num_steps, device=None):\n",
    "    if device is None:\n",
    "        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    corpus_indices = torch.tensor(corpus_indices, dtype=torch.float32, device=device)\n",
    "    # 原始数据长度\n",
    "    data_len = len(corpus_indices)\n",
    "    batch_len = data_len // batch_size\n",
    "    indices = corpus_indices[0: batch_size*batch_len].view(batch_size, batch_len)\n",
    "    # 减1是因为输出的索引x是相应输入的索引y加1\n",
    "    epoch_size = (batch_len - 1) // num_steps\n",
    "    for i in range(epoch_size):\n",
    "        i = i * num_steps\n",
    "        X = indices[:, i: i + num_steps]\n",
    "        Y = indices[:, i + 1: i + num_steps + 1]\n",
    "        yield X, Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X:  tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],\n",
      "        [15., 16., 17., 18., 19., 20.]], device='cuda:0') \n",
      "Y: tensor([[ 1.,  2.,  3.,  4.,  5.,  6.],\n",
      "        [16., 17., 18., 19., 20., 21.]], device='cuda:0') \n",
      "\n",
      "X:  tensor([[ 6.,  7.,  8.,  9., 10., 11.],\n",
      "        [21., 22., 23., 24., 25., 26.]], device='cuda:0') \n",
      "Y: tensor([[ 7.,  8.,  9., 10., 11., 12.],\n",
      "        [22., 23., 24., 25., 26., 27.]], device='cuda:0') \n",
      "\n"
     ]
    }
   ],
   "source": [
    "for X, Y in data_iter_consecutive(my_seq, batch_size=2, num_steps=6):\n",
    "    print('X: ', X, '\\nY:', Y, '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.4 循环神经网络的从零开始实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 0, 0,  ..., 0, 0, 0],\n",
       "        [0, 0, 1,  ..., 0, 0, 0]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "F.one_hot(torch.tensor([0, 2]), vocab_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, torch.Size([2, 1027]))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def to_onehot(X, size):\n",
    "    return F.one_hot(X.t(), size)\n",
    "X = torch.arange(10).view(2, 5)\n",
    "inputs = to_onehot(X, vocab_size)\n",
    "len(inputs), inputs[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "will use cuda\n"
     ]
    }
   ],
   "source": [
    "num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size\n",
    "print('will use', device)\n",
    "import numpy as np\n",
    "def get_params():\n",
    "    def _one(shape):\n",
    "        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)\n",
    "        return torch.nn.Parameter(ts, requires_grad=True)\n",
    "    # 隐藏层参数\n",
    "    W_xh = _one((num_inputs, num_hiddens))\n",
    "    W_hh = _one((num_hiddens, num_hiddens))\n",
    "    b_h = torch.nn.Parameter(torch.zeros(num_hiddens, device=device, requires_grad=True))\n",
    "    # 输出层参数\n",
    "    W_hq = _one((num_hiddens, num_outputs))\n",
    "    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, requires_grad=True))\n",
    "    return torch.nn.ParameterList([W_xh, W_hh, b_h, W_hq, b_q])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_rnn_state(batch_size, num_hiddens, device):\n",
    "    return (torch.zeros((batch_size, num_hiddens), device=device), )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rnn(inputs, state, params):\n",
    "    # inputs和outputs皆为num_steps个形状为(batch_size, vocab_size)的矩阵\n",
    "    W_xh, W_hh, b_h, W_hq, b_q = params\n",
    "    # 上一层传来的隐藏状态\n",
    "    H, = state\n",
    "    outputs = []\n",
    "    for X in inputs:\n",
    "        H = torch.tanh(torch.matmul(X.float(), W_xh) + torch.matmul(H, W_hh) + b_h)\n",
    "        Y = torch.matmul(H, W_hq) + b_q\n",
    "        outputs.append(Y)\n",
    "    return outputs, (H, )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5 torch.Size([2, 1027]) torch.Size([2, 256])\n"
     ]
    }
   ],
   "source": [
    "state = init_rnn_state(X.shape[0], num_hiddens, device)\n",
    "inputs = to_onehot(X.to(device), vocab_size)\n",
    "params = get_params()\n",
    "outputs, state_new = rnn(inputs, state, params)\n",
    "print(len(outputs), outputs[0].shape, state_new[0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_rnn(prefix, num_chars, rnn, params, init_rnn_state, num_hiddens, vocab_size, device, idx_to_char, char_to_idx):\n",
    "    # 初始化的隐藏状态\n",
    "    state = init_rnn_state(1, num_hiddens, device)\n",
    "    # 将输入的首字符传入到输出序列中\n",
    "    output = [char_to_idx[prefix[0]]]\n",
    "    for t in range(num_chars + len(prefix) - 1):\n",
    "        # 将上一时间步的输出作为当前时间步的输入\n",
    "        X = to_onehot(torch.tensor([[output[-1]]], device=device), vocab_size)\n",
    "        # 计算输出和更新隐藏状态\n",
    "        (Y, state) = rnn(X, state, params)\n",
    "        # 下一个时间步的输入是prefix里的字符或者当前的最佳预测字符\n",
    "        if t < len(prefix)-1:\n",
    "            output.append(char_to_idx[prefix[t+1]])\n",
    "        else:\n",
    "            output.append(int(Y[0].argmax(dim=1).item()))\n",
    "    return ''.join([idx_to_char[i] for i in output])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'分开短因三狠在够枪猎馆宙'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict_rnn('分开', 10, rnn, params, init_rnn_state, num_hiddens, vocab_size, device, idx_to_char, char_to_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad_clipping(params, theta, device):\n",
    "    # 存储参数梯度值的平方和开根号\n",
    "    norm = torch.tensor([0.0], device=device)\n",
    "    for param in params:\n",
    "        norm += (param.grad.data**2).sum()\n",
    "    norm = norm.sqrt().item()\n",
    "    if norm > theta:\n",
    "        for param in params:\n",
    "            param.grad.data *= (theta/norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import math\n",
    "def train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens, vocab_size, device, corpus_indices, idx_to_char, char_to_idx, is_random_iter, num_epochs, num_steps, lr, clipping_theta, \n",
    "                         batch_size, pred_period, pred_len, prefixes):\n",
    "    if is_random_iter:\n",
    "        # 随机采样\n",
    "        data_iter_fn = data_iter_random\n",
    "    else:\n",
    "        # 相邻采样\n",
    "        data_iter_fn = data_iter_consecutive\n",
    "    # 初始化模型参数\n",
    "    params = get_params()\n",
    "    # 交叉熵损失\n",
    "    loss = torch.nn.CrossEntropyLoss()\n",
    "    for epoch in range(num_epochs):\n",
    "        if not is_random_iter:\n",
    "            # 如使用相邻采样，在epoch开始时初始化隐藏状态\n",
    "            state = init_rnn_state(batch_size, num_hiddens, device)\n",
    "        # 损失之和、样本数、起始时间\n",
    "        l_sum, n, start = 0.0, 0, time.time()\n",
    "        # 生成训练样本及label\n",
    "        data_iter = data_iter_fn(corpus_indices, batch_size, num_steps, device)\n",
    "        for X, Y in data_iter:\n",
    "            # 如使用随机采样，在每个小批量更新前初始化隐藏状态\n",
    "            if is_random_iter:\n",
    "                state = init_rnn_state(batch_size, num_hiddens, device)\n",
    "            else:\n",
    "                # 否则需要使用detach函数从计算图分离隐藏状态, 这是为了\n",
    "                # 使模型参数的梯度计算只依赖一次迭代读取的小批量序列(防止梯度计算开销太大)\n",
    "                for s in state:\n",
    "                    s.detach()\n",
    "            # 独热编码\n",
    "            inputs = to_onehot(X.long(), vocab_size)\n",
    "            # outputs有num_steps个形状为(batch_size, vocab_size)的矩阵\n",
    "            (outputs, state) = rnn(inputs, state, params)\n",
    "            # 拼接之后形状为(num_steps * batch_size, vocab_size)\n",
    "            outputs = torch.cat(outputs, dim=0)\n",
    "            # Y的形状是(batch_size, num_steps)，转置后再变成长度为\n",
    "            # batch * num_steps 的向量，这样跟输出的行一一对应\n",
    "            y = torch.transpose(Y, 0, 1).contiguous().view(-1)\n",
    "            # 使用交叉熵损失计算平均分类误差\n",
    "            l = loss(outputs, y.long())\n",
    "            # 梯度清0\n",
    "            if params[0].grad is not None:\n",
    "                for param in params:\n",
    "                    param.grad.data.zero_()\n",
    "            l.backward(retain_graph=True)\n",
    "            # 裁剪梯度\n",
    "            grad_clipping(params, clipping_theta, device)\n",
    "            # 因为误差已经取过均值，梯度不用再做平均\n",
    "            d2l.sgd(params, lr, 1)\n",
    "            l_sum += l.item() * y.shape[0]\n",
    "            n += y.shape[0]\n",
    "        if (epoch + 1) % pred_period == 0:\n",
    "            print('epoch %d, perplexity %f, time %.2f sec' % (epoch+1, math.exp(l_sum/n), time.time()-start))\n",
    "            for prefix in prefixes:\n",
    "                print(' -', predict_rnn(prefix, pred_len, rnn, params, init_rnn_state, num_hiddens, vocab_size, device, idx_to_char, char_to_idx))"
   ]
  },
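  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop reports perplexity, the exponential of the average cross-entropy loss over the $n$ predicted characters:\n",
    "\n",
    "$$\\mathrm{perplexity} = \\exp\\left(-\\frac{1}{n}\\sum_{t=1}^{n} \\log P(x_t \\mid x_1, \\ldots, x_{t-1})\\right).$$\n",
    "\n",
    "A perfect model has perplexity 1, while one that guesses uniformly at random has perplexity equal to `vocab_size` (1027 here), so values well below 1027 indicate the model is learning."
   ]
  },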
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs, num_steps, batch_size, lr, clipping_theta = 250, 35, 32, 1e2, 1e-2\n",
    "pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 50, perplexity 68.997638, time 0.26 sec\n",
      " - 分开 我想你这想你 不知我 别你我 三子我 三你我 三你我 三你我 三你我 三你我 三你我 三你我 三你\n",
      " - 不分开 快我不 你想我 三你我 三你的让我疯狂的可爱女人 坏坏的让我疯狂的可爱女人 坏坏的让我疯狂的可爱女\n",
      "epoch 100, perplexity 9.674701, time 0.26 sec\n",
      " - 分开 一直在被 在谁村 我 再这我 别弄是 我给我这我不要再难怎 深不和 旧颗的 不颗走 一九四颗三颗四\n",
      " - 不分开堡 我不能再想 我不 我不 我不要再想 我不 我不 我不要再想 我不 我不 我不要再想 我不 我不 \n",
      "epoch 150, perplexity 2.918735, time 0.26 sec\n",
      " - 分开 娘子心碎心人 哼哼哈兮 快使用双截棍 哼哼哈兮 如使用双截棍 哼哼哈兮 如使我有轻功 飞檐走壁 为\n",
      " - 不分开扫 我后你爸 你打我妈 这样跟吗 全对出纵 恨敢有从 没有没空 你在操空 你不懂纵 不敢心从 说人止\n",
      "epoch 200, perplexity 1.590260, time 0.26 sec\n",
      " - 分开 娘子在用 不人之容 说一场梦 不敢不同 你不懂 连一句珍重 也有苦衷的字笑 回都就待了作来 有一些\n",
      " - 不分开扫 我叫你爸 你打我妈 这样对吗干嘛这样 何必让酒 我对儿带恼力 我知得好国力 不知不觉 你已经离开\n",
      "epoch 250, perplexity 1.321758, time 0.26 sec\n",
      " - 分开 一只到 心给我 娘分 是我去回头 布 在人胸打我只无 难发  静有我烦球有恼要 我不我这辈子注定一\n",
      " - 不分开期把的胖女巫 用拉丁文念咒语啦啦呜 她养的黑猫笑起来像哭 啦啦啦呜 水风我早已经猜透 透是我 别怪我\n"
     ]
    }
   ],
   "source": [
    "train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,\n",
    "                      vocab_size, device, corpus_indices, idx_to_char,\n",
    "                      char_to_idx, True, num_epochs, num_steps, lr,\n",
    "                      clipping_theta, batch_size, pred_period, pred_len,\n",
    "                      prefixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 50, perplexity 61.255773, time 0.55 sec\n",
      " - 分开 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我\n",
      " - 不分开 你不了有 你谁了空 在谁在人 在谁的让 在果的让 在果的可 在人的可 在人的可 在人的可 在人的可\n",
      "epoch 100, perplexity 6.411222, time 0.54 sec\n",
      " - 分开 我不要再样 我不懂再 在我村外的溪边河口默默等著我 娘子依旧每日折一枝杨柳 你 那着 在小两外的溪\n",
      " - 不分开柳 你已经双  让有梦 如有箱中停留 你有在 瞎透了 什么了有 沙头一碗热粥 配上几斤的牛肉 我说店\n",
      "epoch 150, perplexity 1.914138, time 0.55 sec\n",
      " - 分开 我不懂 爱么我 三过了明 在故心中的溪边 默默等待 娘子 一什么酒 再来一碗热粥 配上几斤的牛肉 \n",
      " - 不分开柳 你已经离著的 内知哈 娘是她人在江南有我 泪不休 语沉默娘子她人在江南等我 爱不休 语沉默 一壶\n",
      "epoch 200, perplexity 1.674830, time 0.55 sec\n",
      " - 分开 她候的 旧时光 是属于那年 老师坊 旧皮箱 是属于那年 老师坊 旧皮箱 是属于 传满于中年的白墙 \n",
      " - 不分开觉 你已经离开你听抱 别会的钟我撒娇 从你睡著一直到老 就是开不了口让她知道 就是那么简单几句 我办\n",
      "epoch 250, perplexity 6.651694, time 0.54 sec\n",
      " - 分开 你已在 是时么 一九么那年每天 一壶现 一什么 一颗堂中旧一 还说骷依一中妈在 用著你 是给眼的风\n",
      " - 不分开柳的美 一小走受是满下的暴息 你不定的想动 然黄里 娘点已的我不于人的温远 夕阳前可了酒下在天忆 你\n"
     ]
    }
   ],
   "source": [
    "train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,\n",
    "                      vocab_size, device, corpus_indices, idx_to_char,\n",
    "                      char_to_idx, False, num_epochs, num_steps, lr,\n",
    "                      clipping_theta, batch_size, pred_period, pred_len,\n",
    "                      prefixes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.5 循环神经网络的简洁实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "num_hiddens = 256\n",
    "rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens)"
   ]
  },
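  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the default `batch_first=False`, `nn.RNN` expects input of shape `(seq_len, batch, input_size)` and a hidden state of shape `(num_layers * num_directions, batch, hidden_size)`. Its forward pass returns the per-time-step outputs of shape `(seq_len, batch, num_directions * hidden_size)` together with the final hidden state; the next two cells check these shapes."
   ]
  },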
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 2, 256])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state = torch.zeros((1, 2, num_hiddens))\n",
    "state.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 2, 256]) 1 torch.Size([2, 256])\n"
     ]
    }
   ],
   "source": [
    "num_steps = 32\n",
    "batch_size = 2\n",
    "X = torch.rand(num_steps, batch_size, vocab_size)\n",
    "Y, state_new = rnn_layer(X, state)\n",
    "print(Y.shape, len(state_new), state_new[0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNModel(nn.Module):\n",
    "    def __init__(self, rnn_layer, vocab_size):\n",
    "        super(RNNModel, self).__init__()\n",
    "        self.rnn = rnn_layer\n",
    "        # 若为双向循环网络需要×2\n",
    "        self.hidden_size = rnn_layer.hidden_size*(2 if rnn_layer.bidirectional else 1)\n",
    "        self.vocab_size = vocab_size\n",
    "        self.dense = nn.Linear(self.hidden_size, vocab_size)\n",
    "        self.state = None\n",
    "    # inputs: (batch, seq_len)\n",
    "    def forward(self, inputs, state):\n",
    "        # 获取one-hot向量表示\n",
    "        # X是个list\n",
    "        X = d2l.to_onehot(inputs.long(), self.vocab_size)\n",
    "        Y, self.state = self.rnn(X.float(), state)\n",
    "        # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens)\n",
    "        # 它的输出形状为(num_steps * batch_size, vocab_size)\n",
    "        output = self.dense(Y.view(-1, Y.shape[-1]))\n",
    "        return output, self.state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device, idx_to_char, char_to_idx):\n",
    "    state = None\n",
    "    # output会记录prefix加上输出\n",
    "    # 将输入的首字符传入到输出序列中\n",
    "    output = [char_to_idx[prefix[0]]]\n",
    "    for t in range(num_chars+len(prefix)-1):\n",
    "        # 将上一时间步的输出作为当前时间步的输入\n",
    "        X = torch.tensor([output[-1]], device=device).view(1, 1)\n",
    "        if state is not None:\n",
    "            if isinstance(state, tuple):\n",
    "                state = (state[0].to(device), state[1].to(device))\n",
    "            else:\n",
    "                state = state.to(device)\n",
    "        (Y, state) = model(X, state)\n",
    "        if t < len(prefix) - 1:\n",
    "            output.append(char_to_idx[prefix[t+1]])\n",
    "        else:\n",
    "            output.append(int(Y.argmax(dim=1).item()))\n",
    "    return ''.join([idx_to_char[i] for i in output])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'分开歌斑音觉涌鹿田斑去用'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = RNNModel(rnn_layer, vocab_size).to(device)\n",
    "predict_rnn_pytorch('分开', 10, model, vocab_size, device, idx_to_char, char_to_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device, corpus_indices, idx_to_char, char_to_idx, num_epochs, num_steps, lr, clipping_theta, \n",
    "                                  batch_size, pred_period, pred_len, prefixes):\n",
    "    # 交叉熵损失\n",
    "    loss = torch.nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "    model.to(device)\n",
    "    state = None\n",
    "    for epoch in range(num_epochs):\n",
    "        # 损失之和、样本数、起始时间\n",
    "        l_sum, n, start = 0.0, 0, time.time()\n",
    "        # 相邻采样\n",
    "        data_iter = d2l.data_iter_consecutive(corpus_indices, batch_size, num_steps, device)\n",
    "        for X, Y in data_iter:\n",
    "            if state is not None:\n",
    "                # 否则需要使用detach函数从计算图分离隐藏状态, 这是为了\n",
    "                # 使模型参数的梯度计算只依赖一次迭代读取的小批量序列(防止梯度计算开销太大)\n",
    "                if isinstance(state, tuple):\n",
    "                    state = (state[0].detach(), state[1].detach())\n",
    "                else:\n",
    "                    state = state.detach()\n",
    "            # output: 形状为(num_steps * batch_size, vocab_size)\n",
    "            (output, state) = model(X, state)\n",
    "            # Y的形状是(batch_size, num_steps)，转置后再变成长度为\n",
    "            # batch * num_steps 的向量，这样跟输出的行一一对应\n",
    "            y = torch.transpose(Y, 0, 1).contiguous().view(-1)\n",
    "            # 使用交叉熵损失计算平均分类误差\n",
    "            l = loss(output, y.long())\n",
    "            # 梯度清0\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # https://blog.csdn.net/qq_29695701/article/details/89965986\n",
    "            l.backward()\n",
    "            # 裁剪梯度\n",
    "            d2l.grad_clipping(model.parameters(), clipping_theta, device)\n",
    "            optimizer.step()\n",
    "            l_sum += l.item() * y.shape[0]\n",
    "            n += y.shape[0]\n",
    "        try:\n",
    "            perplexity = math.exp(l_sum / n)\n",
    "        except OverflowError:\n",
    "            perplexity = float('inf')\n",
    "        if (epoch + 1) % pred_period == 0:\n",
    "            print('epoch %d, perplexity %f, time %.2f sec' % (epoch+1, perplexity, time.time()-start))\n",
    "            for prefix in prefixes:\n",
    "                print(' -', predict_rnn_pytorch(prefix, pred_len, model, vocab_size, device, idx_to_char, char_to_idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 50, perplexity 1.152438, time 0.07 sec\n",
      " - 分开的我想就像 不要 我 在进黑色幽默 爱你的明的手不放开 爱可不可以简简单单没有伤害 你 靠着我的肩膀\n",
      " - 不分开不要活 也有苦笑 没有你在的爸爸一句带酒 败给你的黑色幽默 不想太多 我想 我的肩膀 你 在我胸口睡\n",
      "epoch 100, perplexity 1.146996, time 0.07 sec\n",
      " - 分开的我后 快使用双截棍 哼哼哈兮 快使用双截棍 哼哼哈兮 快使用双截棍 哼哼哈兮 快使用双截棍 哼哼哈\n",
      " - 不分开不美边 这样的梦 出现裂缝 隐隐作痛 随便个战 弓箭 太多的话去对医药箱说 别怪我 别怪我 三分球 \n",
      "epoch 150, perplexity 1.026819, time 0.07 sec\n",
      " - 分开的我有我这辈子注定一个人演戏 最后再一个人慢慢的回忆 没有了过去 我将往事抽离 如果我遇见你是一场悲\n",
      " - 不分开不要再这样的节奏 谁都无可奈何 没有你以后 我灵魂失控 黑云在降落 我被它拖着走 静静悄悄默默离开 \n",
      "epoch 200, perplexity 1.022801, time 0.07 sec\n",
      " - 分开的我有我爱你 爱情来的太快就像龙卷风 离不开暴风圈来不及逃 我不能再想 我不能再想 我不 我不 我不\n",
      " - 不分开不知天空 分手不这样牵着你的手不放开 爱可不可以简简单单没有伤害 你 靠着我的肩膀 你 在我胸口睡著\n",
      "epoch 250, perplexity 1.024344, time 0.07 sec\n",
      " - 分开的我后悔着对不起 一枚铜币 悲伤得很隐密 它在许愿池里轻轻叹息 太多的我爱你 让它喘不过气 已经 失\n",
      " - 不分开 我有一双翅膀 二双翅膀 随时出发 偷偷出发 我一定带我妈走  从前的教育别人的家庭  别人的爸爸种\n"
     ]
    }
   ],
   "source": [
    "num_epochs, batch_size, lr, clipping_theta = 250, 32, 1e-2, 1e-2 # 注意这里的学习率设置\n",
    "pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']\n",
    "train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,\n",
    "                            corpus_indices, idx_to_char, char_to_idx,\n",
    "                            num_epochs, num_steps, lr, clipping_theta,\n",
    "                            batch_size, pred_period, pred_len, prefixes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.7 GRU-门控循环单元"
   ]
  },
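  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A gated recurrent unit computes, at each time step,\n",
    "\n",
    "$$\\boldsymbol{Z}_t = \\sigma(\\boldsymbol{X}_t \\boldsymbol{W}_{xz} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hz} + \\boldsymbol{b}_z)$$\n",
    "$$\\boldsymbol{R}_t = \\sigma(\\boldsymbol{X}_t \\boldsymbol{W}_{xr} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hr} + \\boldsymbol{b}_r)$$\n",
    "$$\\tilde{\\boldsymbol{H}}_t = \\tanh(\\boldsymbol{X}_t \\boldsymbol{W}_{xh} + (\\boldsymbol{R}_t \\odot \\boldsymbol{H}_{t-1}) \\boldsymbol{W}_{hh} + \\boldsymbol{b}_h)$$\n",
    "$$\\boldsymbol{H}_t = \\boldsymbol{Z}_t \\odot \\boldsymbol{H}_{t-1} + (1 - \\boldsymbol{Z}_t) \\odot \\tilde{\\boldsymbol{H}}_t$$\n",
    "\n",
    "where $\\boldsymbol{Z}_t$ is the update gate, $\\boldsymbol{R}_t$ the reset gate, and $\\odot$ denotes element-wise multiplication. The `gru` function below implements these equations directly."
   ]
  },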
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "will use cuda\n"
     ]
    }
   ],
   "source": [
    "num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size\n",
    "print('will use', device)\n",
    "def get_params():\n",
    "    def _one(shape):\n",
    "        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)\n",
    "        return torch.nn.Parameter(ts, requires_grad=True)\n",
    "    def _three():\n",
    "        return (_one((num_inputs, num_hiddens)), \n",
    "               _one((num_hiddens, num_hiddens)), \n",
    "               torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))\n",
    "    # 更新门参数\n",
    "    W_xz, W_hz, b_z = _three()\n",
    "    # 重置门参数\n",
    "    W_xr, W_hr, b_r = _three()\n",
    "    # 候选隐藏状态参数\n",
    "    W_xh, W_hh, b_h = _three()\n",
    "    # 输出层参数\n",
    "    W_hq = _one((num_hiddens, num_outputs))\n",
    "    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)\n",
    "    return nn.ParameterList([W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_gru_state(batch_size, num_hiddens, device):\n",
    "    return (torch.zeros((batch_size, num_hiddens), device=device), )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gru(inputs, state, params):\n",
    "    # 参数\n",
    "    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params\n",
    "    # 初始化的隐藏状态\n",
    "    H, = state\n",
    "    outputs = []\n",
    "    for X in inputs:\n",
    "        Z = torch.sigmoid(torch.matmul(X.float(), W_xz) + torch.matmul(H, W_hz) + b_z)\n",
    "        R = torch.sigmoid(torch.matmul(X.float(), W_xr) + torch.matmul(H, W_hr) + b_r)\n",
    "        H_tilda = torch.tanh(torch.matmul(X.float(), W_xh) + torch.matmul(R*H, W_hh) + b_h)\n",
    "        H = Z*H + (1-Z)*H_tilda\n",
    "        Y = torch.matmul(H, W_hq) + b_q\n",
    "        outputs.append(Y)\n",
    "    return outputs, (H,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2\n",
    "pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 40, perplexity 149.917021, time 2.04 sec\n",
      " - 分开 我想你 我想你 你不的我 你不的让我 你想我 你不我 你不的我 你不的让我 你想我 你不我 你不的\n",
      " - 不分开 我想你 我想你 你不的我 你不的让我 你想我 你不我 你不的我 你不的让我 你想我 你不我 你不的\n",
      "epoch 80, perplexity 32.286560, time 2.07 sec\n",
      " - 分开 我想要这样 我不要再想 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我\n",
      " - 不分开  我有你这样我 一场好觉 你爱我 别你的手 在人放人 温爱在人人  爱我有你的微笑 像龙放口 让我\n",
      "epoch 120, perplexity 5.518456, time 2.04 sec\n",
      " - 分开 我想要你的微笑每天都能看离  我知道这里很美但家乡的你更美 我想要这样坦堡 像这样的生色 让我开红\n",
      " - 不分开 我已能这样打我妈妈 我不能再想  不知不觉 我不要再想 我不能再想 我不 我不 我不能 爱情走的太\n",
      "epoch 160, perplexity 1.665924, time 2.02 sec\n",
      " - 分开 小作    所有那里对着我进攻       古巴忆对王颁布了汉摩拉比比典轻轻的脸窗  古录像一个世\n",
      " - 不分开 你是一直 是你的那笑 喜的风美主义 平伤你的手 一阵莫名感动 我以带你 回我的外婆家 一起看着日落\n"
     ]
    }
   ],
   "source": [
    "d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens,\n",
    "                      vocab_size, device, corpus_indices, idx_to_char,\n",
    "                      char_to_idx, False, num_epochs, num_steps, lr,\n",
    "                      clipping_theta, batch_size, pred_period, pred_len,\n",
    "                      prefixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 40, perplexity 1.016270, time 0.07 sec\n",
      " - 分开始乡情怯的我 相思寄红豆 相思寄红豆无能为力的在人海中漂泊心伤透 娘子她人在江南等我 泪不休 语沉默\n",
      " - 不分开始乡相信命运 感谢地心引力 让我碰到你 漂亮的让我面红的可爱女人 温柔的让我心疼的可爱女人 透明的让\n",
      "epoch 80, perplexity 1.009380, time 0.07 sec\n",
      " - 分开始乡情怯的我 相思寄红豆 相思寄红豆无能为力的在人海中漂泊心伤透 娘子她人在江南等我 泪不休 语沉默\n",
      " - 不分开始打呼 管家是一只会说法语举止优雅的猪 吸血前会念约翰福音做为弥补 拥有一双蓝色眼睛的凯萨琳公主 专\n",
      "epoch 120, perplexity 1.009627, time 0.07 sec\n",
      " - 分开始乡还为分手不投 又不会掩护我 选你这种队友 瞎透了我 说你说 分数怎么停留 一直在停留 谁让它停留\n",
      " - 不分开始乡相信命运 感谢地心引力 让我碰到你 漂亮的让我面红的可爱女人 温柔的让我心疼的可爱女人 透明的让\n",
      "epoch 160, perplexity 1.011384, time 0.06 sec\n",
      " - 分开始乡还是你的脑袋有问题 随便说说 其实我早已经猜透看透不想多说 只是我怕眼泪撑不住 不懂 你的黑色幽\n",
      " - 不分开始乡相信命运 感谢地心引力 让我碰到你 漂亮的让我面红的可爱女人 温柔的让我心疼的可爱女人 透明的让\n"
     ]
    }
   ],
   "source": [
    "lr = 1e-2\n",
    "gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens)\n",
    "model = d2l.RNNModel(gru_layer, vocab_size).to(device)\n",
    "d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,\n",
    "                            corpus_indices, idx_to_char, char_to_idx,\n",
    "                            num_epochs, num_steps, lr, clipping_theta,\n",
    "                            batch_size, pred_period, pred_len, prefixes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.8 LSTM-长短期记忆"
   ]
  },
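  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Long short-term memory adds a memory cell $\\boldsymbol{C}_t$ alongside the hidden state and controls it with three gates:\n",
    "\n",
    "$$\\boldsymbol{I}_t = \\sigma(\\boldsymbol{X}_t \\boldsymbol{W}_{xi} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hi} + \\boldsymbol{b}_i)$$\n",
    "$$\\boldsymbol{F}_t = \\sigma(\\boldsymbol{X}_t \\boldsymbol{W}_{xf} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hf} + \\boldsymbol{b}_f)$$\n",
    "$$\\boldsymbol{O}_t = \\sigma(\\boldsymbol{X}_t \\boldsymbol{W}_{xo} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{ho} + \\boldsymbol{b}_o)$$\n",
    "$$\\tilde{\\boldsymbol{C}}_t = \\tanh(\\boldsymbol{X}_t \\boldsymbol{W}_{xc} + \\boldsymbol{H}_{t-1} \\boldsymbol{W}_{hc} + \\boldsymbol{b}_c)$$\n",
    "$$\\boldsymbol{C}_t = \\boldsymbol{F}_t \\odot \\boldsymbol{C}_{t-1} + \\boldsymbol{I}_t \\odot \\tilde{\\boldsymbol{C}}_t$$\n",
    "$$\\boldsymbol{H}_t = \\boldsymbol{O}_t \\odot \\tanh(\\boldsymbol{C}_t)$$\n",
    "\n",
    "with input gate $\\boldsymbol{I}_t$, forget gate $\\boldsymbol{F}_t$, and output gate $\\boldsymbol{O}_t$. The `lstm` function below follows these equations line by line."
   ]
  },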
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "will use cuda\n"
     ]
    }
   ],
   "source": [
    "num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size\n",
    "print('will use', device)\n",
    "def get_params():\n",
    "    def _one(shape):\n",
    "        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)\n",
    "        return torch.nn.Parameter(ts, requires_grad=True)\n",
    "    def _three():\n",
    "        return (_one((num_inputs, num_hiddens)), \n",
    "               _one((num_hiddens, num_hiddens)), \n",
    "               torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))\n",
    "    # 输入门参数\n",
    "    W_xi, W_hi, b_i = _three()\n",
    "    # 遗忘门参数\n",
    "    W_xf, W_hf, b_f = _three()\n",
    "    # 输出门参数\n",
    "    W_xo, W_ho, b_o = _three()\n",
    "    # 候选记忆细胞参数\n",
    "    W_xc, W_hc, b_c = _three()\n",
    "    # 输出层参数\n",
    "    W_hq = _one((num_hiddens, num_outputs))\n",
    "    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)\n",
    "    return nn.ParameterList([W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_lstm_state(batch_size, num_hiddens, device):\n",
    "    return (torch.zeros((batch_size, num_hiddens), device=device), \n",
    "           torch.zeros((batch_size, num_hiddens), device=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lstm(inputs, state, params):\n",
    "    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params\n",
    "    (H, C) = state\n",
    "    outputs = []\n",
    "    for X in inputs:\n",
    "        I = torch.sigmoid(torch.matmul(X.float(), W_xi) + torch.matmul(H, W_hi) + b_i)\n",
    "        F = torch.sigmoid(torch.matmul(X.float(), W_xf) + torch.matmul(H, W_hf) + b_f)\n",
    "        O = torch.sigmoid(torch.matmul(X.float(), W_xo) + torch.matmul(H, W_ho) + b_o)\n",
    "        C_tilda = torch.tanh(torch.matmul(X.float(), W_xc) + torch.matmul(H, W_hc) + b_c)\n",
    "        C = F * C + I * C_tilda\n",
    "        H = O * C.tanh()\n",
    "        Y = torch.matmul(H, W_hq) + b_q\n",
    "        outputs.append(Y)\n",
    "    return outputs, (H, C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2\n",
    "pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 40, perplexity 213.158646, time 2.46 sec\n",
      " - 分开 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我\n",
      " - 不分开 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我\n",
      "epoch 80, perplexity 70.030242, time 2.48 sec\n",
      " - 分开 我想你你的你 我想想你 你不了 我不要 我不要 我不要 我不要 我不要 我不要 我不要 我不要 我\n",
      " - 不分开 我想你你 我不要 我不要 我不要 我不要 我不要 我不要 我不要 我不要 我不要 我不要 我不要 \n",
      "epoch 120, perplexity 17.365369, time 2.38 sec\n",
      " - 分开 我说你的你笑我 你样的风  我有你的爱你 你元人 深埋在美索不达米亚 我想要你 爱样我 想样  没\n",
      " - 不分开 我知你这想堡堡  想想你的你笑着 像么  有有你的肩着 我想想的你  我 我想你的你膀  没 在你\n",
      "epoch 160, perplexity 4.445391, time 2.43 sec\n",
      " - 分开 我说带的话笑 后悔在对不起 景色入秋 漫天了空凉我 塞北的客栈人多 牧草有没有 我马儿有些瘦 天涯\n",
      " - 不分开 我已你这样笑着听 想是开 说说 一颗两颗 在上四碗 在一上空的在老 还上上看看 让家 在什人 我开\n"
     ]
    }
   ],
   "source": [
    "d2l.train_and_predict_rnn(lstm, get_params, init_lstm_state, num_hiddens,\n",
    "                      vocab_size, device, corpus_indices, idx_to_char,\n",
    "                      char_to_idx, False, num_epochs, num_steps, lr,\n",
    "                      clipping_theta, batch_size, pred_period, pred_len,\n",
    "                      prefixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 40, perplexity 1.020968, time 0.08 sec\n",
      " - 分开的太快就像龙卷风 离不开暴风圈来不及逃 我不能再想 我不能再想 我不 我不 我不能 爱情走的太快就像\n",
      " - 不分开 我有轻功 飞檐走壁 为人耿直不屈 一身正气 他们儿子我习惯 从小就耳濡目染 什么刀枪跟棍棒 我都耍\n",
      "epoch 80, perplexity 1.014793, time 0.07 sec\n",
      " - 分开的玩笑 想通 却又再考倒我 说散 你想很久了吧? 败给你的黑色幽默 说散 你想很久了吧? 我的认真败\n",
      " - 不分开 我有多看着你 有话去对医药箱说 别怪我 别怪我 说你怎么面对我 甩开球我满腔的怒火 我想揍你已经很\n",
      "epoch 120, perplexity 1.012057, time 0.07 sec\n",
      " - 分开的玩笑 想通 却又再考倒我 说散 你想很久了吧? 败给你的黑色幽默 说散 你想很久了吧? 我的认真败\n",
      " - 不分开 我有了空 是你完美演出的一场戏 宁愿心碎哭泣 再狠狠忘记 你爱过我的证据 让晶莹的泪滴 闪烁成回忆\n",
      "epoch 160, perplexity 1.008711, time 0.07 sec\n",
      " - 分开的快乐是你 想你想的都会笑 没有你在 我有多难熬  没有你在我有多难熬多烦恼  没有你烦 我有多烦恼\n",
      " - 不分开 我有轻重 就是开不了口让她知道 就是那么简单几句 我办不到 整颗心悬在半空 我只能够远远看著 这些\n"
     ]
    }
   ],
   "source": [
    "lr = 1e-2\n",
    "lstm_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens)\n",
    "model = d2l.RNNModel(lstm_layer, vocab_size)\n",
    "d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,\n",
    "                            corpus_indices, idx_to_char, char_to_idx,\n",
    "                            num_epochs, num_steps, lr, clipping_theta,\n",
    "                            batch_size, pred_period, pred_len, prefixes)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
